import numpy as np
from math import *
from env import single_expert_dynamics,single_expert_stochastic_dynamics, expert1_reward, expert2_reward, expert3_reward, expert1_cost, expert2_cost, expert3_cost, feature1, feature2, feature3, expert_1_basis_constraint, expert_2_basis_constraint, expert_3_basis_constraint
from multiprocessing import Process

def empirical_feature_counts(trajectories,num_data):
  """Average, over trajectories, the summed stacked feature vectors.

  ``trajectories`` is read in consecutive groups of 30 rows, one group per
  trajectory.  In each row, columns 0-5 are taken as the joint state and
  columns 6-8 as the three actions.  The outputs of ``feature1``,
  ``feature2`` and ``feature3`` are stacked with ``np.vstack``, summed over
  the 30 rows of a trajectory, and the per-trajectory sums are averaged over
  ``num_data`` trajectories.
  """
  horizon=30
  total=0
  for episode in range(num_data):
    rows=trajectories[horizon*episode:horizon*(episode+1),:]
    episode_sum=0
    for step in range(horizon):
      row=rows[step]
      state=np.mat(np.copy(row[0:6])).T
      action=np.mat(np.copy(row[6:9])).T
      first=feature1(state[0:2],action[0])
      second=feature2(state[2:4],action[1])
      third=feature3(state[4:6],action[2])
      episode_sum=episode_sum+np.vstack((first,second,third))
    total=total+episode_sum
  return total/num_data

def empirical_cost_counts(trajectories,num_data,theta):
  """Average, over trajectories, the summed cost of the three experts.

  ``theta`` is split into ``theta[0:8]``, ``theta[8:16]`` and
  ``theta[16:26]`` for experts 1, 2 and 3.  ``trajectories`` is read in
  consecutive groups of 30 rows (columns 0-5 joint state, columns 6-8
  actions); the three expert costs are added per row, summed over the 30
  rows, and the per-trajectory totals are averaged over ``num_data``.
  """
  horizon=30
  first_theta=theta[0:8]
  second_theta=theta[8:16]
  third_theta=theta[16:26]
  total=0
  for episode in range(num_data):
    rows=trajectories[horizon*episode:horizon*(episode+1),:]
    episode_sum=0
    for step in range(horizon):
      row=rows[step]
      state=np.mat(np.copy(row[0:6])).T
      action=np.mat(np.copy(row[6:9])).T
      first=expert1_cost(first_theta,state[0:2],action[0])
      second=expert2_cost(second_theta,state[2:4],action[1])
      third=expert3_cost(third_theta,state[4:6],action[2])
      episode_sum=episode_sum+(first+second+third)
    total=total+episode_sum
  return total/num_data

def empirical_constraint_counts(trajectories,num_data,lam):
  """Average, over trajectories, the summed ``lam``-scaled constraint bases.

  ``trajectories`` is read in consecutive groups of 30 rows (columns 0-5
  joint state, columns 6-8 actions).  For every row the outputs of the three
  ``expert_*_basis_constraint`` functions are stacked with ``np.vstack`` and
  multiplied by ``lam``; these are summed over the 30 rows and the
  per-trajectory sums are averaged over ``num_data``.
  """
  horizon=30
  total=0
  for episode in range(num_data):
    rows=trajectories[horizon*episode:horizon*(episode+1),:]
    episode_sum=0
    for step in range(horizon):
      row=rows[step]
      state=np.mat(np.copy(row[0:6])).T
      action=np.mat(np.copy(row[6:9])).T
      first=expert_1_basis_constraint(state[0:2],action[0])
      second=expert_2_basis_constraint(state[2:4],action[1])
      third=expert_3_basis_constraint(state[4:6],action[2])
      stacked=np.vstack((first,second,third))
      episode_sum=episode_sum+lam*stacked
    total=total+episode_sum
  return total/num_data


def choose_action(policy_distribution):  # distribution is 9x1
  """Sample an action index from a discrete probability distribution.

  One uniform number in [0, 1) is drawn and the first index whose cumulative
  probability reaches it is returned.

  Args:
    policy_distribution: indexable sequence of action probabilities.  Its
      own length determines the number of actions.

  Returns:
    The sampled action index.

  Raises:
    ValueError: if ``policy_distribution`` is empty.
  """
  count=len(policy_distribution)
  if count==0:
    raise ValueError("policy_distribution must contain at least one action")
  choice=np.random.uniform()
  cumulative=0.0
  last_positive=count-1
  seen_positive=False
  for a in range(count):
    probability=policy_distribution[a]
    if probability>0:
      last_positive=a
      seen_positive=True
    cumulative=cumulative+probability
    if cumulative>=choice:
      return a
  # Floating-point rounding can leave the total marginally below the draw.
  # Instead of silently returning None, fall back to the last action that
  # carries probability mass (or the last index when none does).
  return last_positive if seen_positive else count-1

def trial(initial_state,policy1,policy2,policy3,num_action):
  """Roll out the three expert policies for 30 steps from ``initial_state``.

  At every step each expert k reads its own pair of state entries
  (items 2k and 2k+1), samples an action with ``choose_action`` from
  ``policyk[row][col]`` and advances through
  ``single_expert_stochastic_dynamics``.  ``num_action`` is accepted but not
  used in the body.

  Returns a list of 30 rows ``[s0, s1, s2, s3, s4, s5, a1, a2, a3]`` holding
  the state before the transition and the three sampled actions.
  """
  steps=[]
  current=initial_state
  for _ in range(30):
    actions=[]
    successors=[]
    # Experts are handled in order so sampling and dynamics calls alternate
    # exactly as expert 1, expert 2, expert 3.
    for k,policy in enumerate((policy1,policy2,policy3)):
      distribution=policy[current.item(2*k)][current.item(2*k+1)][:]
      action=choose_action(distribution)
      successors.append(single_expert_stochastic_dynamics(current[2*k:2*k+2],np.mat([action]).T))
      actions.append(action)
    steps.append([current.item(n) for n in range(6)]+actions)
    current=np.copy(np.vstack(successors))
  return steps

def soft_policy(Q_matrix,V_matrix,num_action):
  """Turn soft Q- and V-values into action probabilities on the 9x9 grid.

  Args:
    Q_matrix: indexable as ``Q_matrix[x][y][a]`` for x, y in 0..8 and
      a in 0..num_action-1.
    V_matrix: indexable as ``V_matrix[x][y]`` for x, y in 0..8.
    num_action: number of actions per state.

  Returns:
    An object-dtype array of shape (9, 9, num_action) whose entry
    ``[x][y][a]`` is ``exp(Q[x][y][a] - V[x][y])``.
  """
  # The builtin ``object`` replaces the ``np.object`` alias, which newer
  # NumPy releases no longer provide.
  distribution=np.zeros((9,9,num_action)).astype(object)
  for x in range(9):
    for y in range(9):
      state_value=V_matrix[x][y]
      for a in range(num_action):
        # exp(Q - V) equals exp(Q)/exp(V) but does not overflow when the
        # individual values are large.
        distribution[x][y][a]=exp(Q_matrix[x][y][a]-state_value)
  return distribution

def soft_Q_matrix_function(gamma,reward_matrix,lam,cost_matrix,V_matrix,num_action):
  """Compute one soft Q backup for a single expert on the 9x9 grid.

  For every state (x, y) and action a:
  ``Q = reward[x,y,a] - lam*cost[x,y,a] + gamma*(0.8*V[next] + 0.2*V[x][y])``
  where ``next`` is the state returned by ``single_expert_dynamics``.

  Args:
    gamma: discount factor applied to the expected next value.
    reward_matrix: rewards indexed ``[x, y, a]``.
    lam: multiplier applied to the cost.
    cost_matrix: costs indexed ``[x, y, a]``.
    V_matrix: current state values indexed ``[x][y]``.
    num_action: number of actions per state.

  Returns:
    An object-dtype array of shape (9, 9, num_action).
  """
  # The builtin ``object`` replaces the ``np.object`` alias, which newer
  # NumPy releases no longer provide.
  Q_matrix=np.zeros((9,9,num_action)).astype(object)
  for x in range(9):
    for y in range(9):
      for a in range(num_action):
        next_state=single_expert_dynamics(np.mat([x,y]).T,np.mat([a]).T)
        # presumably the move succeeds with probability 0.8 and the agent
        # stays put otherwise - TODO confirm against the stochastic dynamics
        value=0.8*V_matrix[next_state.item(0)][next_state.item(1)]+0.2*V_matrix[x][y]
        Q_matrix[x][y][a]=reward_matrix[x,y,a]-lam*cost_matrix[x,y,a]+gamma*value
  return Q_matrix

  
def soft_V_matrix_funciton(Q_matrix,num_action):
  """Compute soft state values ``V[x][y] = log(sum_a exp(Q[x][y][a]))``.

  Args:
    Q_matrix: indexable as ``Q_matrix[x][y][a]`` for x, y in 0..8 and
      a in 0..num_action-1.
    num_action: number of actions per state.

  Returns:
    An object-dtype array of shape (9, 9) holding float values.
  """
  # The builtin ``object`` replaces the ``np.object`` alias, which newer
  # NumPy releases no longer provide.
  V_matrix=np.zeros((9,9)).astype(object)
  for x in range(9):
    for y in range(9):
      q_values=[float(Q_matrix[x][y][a]) for a in range(num_action)]
      # Shift by the largest Q before exponentiating so large values do not
      # overflow exp(); a non-finite maximum is not used as a shift because
      # subtracting it would produce NaN.
      shift=max(q_values)
      if not isfinite(shift):
        shift=0.0
      V_matrix[x][y]=shift+log(sum(exp(q-shift) for q in q_values))
  return V_matrix

def _max_abs_change(old_matrix,new_matrix):
  """Return the largest absolute entry-wise difference over the 9x9 grid."""
  largest=0.0
  for x in range(9):
    for y in range(9):
      change=abs(old_matrix[x][y]-new_matrix[x][y])
      if largest<change:
        largest=change
  return largest


def calculate_soft_policy(omega,theta,lam,gamma,num_action):
  """Run soft value iteration for the three experts and return their policies.

  Reward and cost tables are filled for every grid state and action using
  ``omega[0:2]``, ``omega[2:4]``, ``omega[4:6]`` and ``theta[0:8]``,
  ``theta[8:16]``, ``theta[16:26]`` for experts 1, 2 and 3.  Starting from
  zero state values, soft Q and soft V backups are repeated for all three
  experts together until none of them changes any state value by more
  than 1.0.

  Args:
    omega: reward parameters, sliced as described above.
    theta: cost parameters, sliced as described above.
    lam: multiplier applied to the cost inside the Q backup.
    gamma: discount factor.
    num_action: number of actions per state.

  Returns:
    A tuple ``(policy1, policy2, policy3)`` of arrays produced by
    ``soft_policy``.
  """
  reward_fns=(expert1_reward,expert2_reward,expert3_reward)
  cost_fns=(expert1_cost,expert2_cost,expert3_cost)
  omega_parts=(omega[0:2],omega[2:4],omega[4:6])
  theta_parts=(theta[0:8],theta[8:16],theta[16:26])

  # Tables are sized by num_action rather than a fixed 9, and use the builtin
  # ``object`` dtype because newer NumPy releases no longer provide
  # ``np.object``.
  reward_matrices=[np.zeros((9,9,num_action)).astype(object) for _ in range(3)]
  cost_matrices=[np.zeros((9,9,num_action)).astype(object) for _ in range(3)]
  for x in range(9):
    for y in range(9):
      for a in range(num_action):
        for k in range(3):
          reward_matrices[k][x,y,a]=reward_fns[k](omega_parts[k],np.mat([x,y]).T,np.mat([a]).T)
        for k in range(3):
          cost_matrices[k][x,y,a]=cost_fns[k](theta_parts[k],np.mat([x,y]).T,np.mat([a]).T)

  value_matrices=[np.zeros((9,9)).astype(object) for _ in range(3)]
  while True:
    q_matrices=[]
    new_value_matrices=[]
    for k in range(3):
      q_matrix=np.copy(soft_Q_matrix_function(gamma,reward_matrices[k],lam,cost_matrices[k],value_matrices[k],num_action))
      q_matrices.append(q_matrix)
      new_value_matrices.append(np.copy(soft_V_matrix_funciton(q_matrix,num_action)))
    changes=[_max_abs_change(old,new) for old,new in zip(value_matrices,new_value_matrices)]
    # All three experts keep iterating until every one of them has settled.
    # NOTE(review): a tolerance of 1.0 is coarse for value iteration -
    # confirm it is intended.
    if not any(change>1.0 for change in changes):
      break
    value_matrices=[np.copy(matrix) for matrix in new_value_matrices]

  policy1,policy2,policy3=[
    np.copy(soft_policy(q_matrix,v_matrix,num_action))
    for q_matrix,v_matrix in zip(q_matrices,new_value_matrices)
  ]
  return policy1, policy2, policy3



def feature_expectation(policy1,policy2,policy3,initial_state,num_trials):
  """Estimate feature counts of the given policies from sampled rollouts.

  Runs ``trial`` ``num_trials`` times from ``initial_state`` (using the
  module-level ``num_action``), stacks the 30x9 rollouts vertically and
  returns ``empirical_feature_counts`` of the stacked array.
  """
  pieces=[np.zeros((0,9))]
  for _ in range(num_trials):
    rollout=np.copy(trial(initial_state,policy1,policy2,policy3,num_action))
    rows=np.zeros((30,9))
    for step in range(30):
      rows[step,:]=np.copy(rollout[step])
    pieces.append(rows)
  return empirical_feature_counts(np.vstack(pieces),num_trials)

def cost_expectation(policy1,policy2,policy3,initial_state,num_trials,theta):
  """Estimate the expected cost of the given policies from sampled rollouts.

  Runs ``trial`` ``num_trials`` times from ``initial_state`` (using the
  module-level ``num_action``), stacks the 30x9 rollouts vertically and
  returns ``empirical_cost_counts`` of the stacked array for ``theta``.
  """
  pieces=[np.zeros((0,9))]
  for _ in range(num_trials):
    rollout=np.copy(trial(initial_state,policy1,policy2,policy3,num_action))
    rows=np.zeros((30,9))
    for step in range(30):
      rows[step,:]=np.copy(rollout[step])
    pieces.append(rows)
  return empirical_cost_counts(np.vstack(pieces),num_trials,theta)

def reward_cost_list(trajectories,num_data):
  """Return per-trajectory total reward and total cost under fixed parameters.

  Uses hard-coded reward weights ``[1, -1]`` for every expert and hard-coded
  cost weights (8 entries for experts 1 and 2, 10 entries for expert 3).
  ``trajectories`` is read in consecutive groups of 30 rows; in each row the
  three experts' rewards and costs are added up and accumulated over the
  trajectory.

  Returns:
    ``(reward_list, cost_list)``, each with one entry per trajectory.
  """
  horizon=30
  reward_weights=[np.mat([1.0,-1.0]).T for _ in range(3)]
  shared_cost_weights=[1.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0]
  cost_weights=[
    np.mat(shared_cost_weights).T,
    np.mat(shared_cost_weights).T,
    np.mat(shared_cost_weights+[0.0,0.0]).T,
  ]
  reward_list=[]
  cost_list=[]
  for episode in range(num_data):
    rows=trajectories[horizon*episode:horizon*(episode+1),:]
    total_reward=0.0
    total_cost=0.0
    for step in range(horizon):
      row=rows[step]
      states=[np.mat(np.copy(row[2*k:2*k+2])).T for k in range(3)]
      actions=[np.mat(np.copy(row[6+k])).T for k in range(3)]
      step_reward=(expert1_reward(reward_weights[0],states[0],actions[0])
                   +expert2_reward(reward_weights[1],states[1],actions[1])
                   +expert3_reward(reward_weights[2],states[2],actions[2]))
      step_cost=(expert1_cost(cost_weights[0],states[0],actions[0])
                 +expert2_cost(cost_weights[1],states[1],actions[1])
                 +expert3_cost(cost_weights[2],states[2],actions[2]))
      total_reward=total_reward+step_reward
      total_cost=total_cost+step_cost
    reward_list.append(total_reward)
    cost_list.append(total_cost)
  return reward_list, cost_list
  

def feature_cost_mean_sd(policy1,policy2,policy3,initial_state,num_trials,theta):
  """Sample rollouts and summarise their features, rewards and costs.

  Runs ``trial`` ``num_trials`` times (using the module-level
  ``num_action``) and stacks the 30x9 rollouts.  ``theta`` is accepted but
  not used in the body; rewards and costs come from ``reward_cost_list``.

  Returns:
    ``(feature_counts, reward_mean, reward_sd, cost_mean, cost_sd)`` where
    the standard deviations are ``sqrt(np.var(...))`` of the per-trajectory
    totals.
  """
  pieces=[np.zeros((0,9))]
  for _ in range(num_trials):
    rollout=np.copy(trial(initial_state,policy1,policy2,policy3,num_action))
    rows=np.zeros((30,9))
    for step in range(30):
      rows[step,:]=np.copy(rollout[step])
    pieces.append(rows)
  stacked=np.vstack(pieces)
  feature_counts=empirical_feature_counts(stacked,num_trials)
  rewards,costs=reward_cost_list(stacked,num_trials)
  reward_mean=sum(rewards)/len(rewards)
  reward_sd=sqrt(np.var(rewards))
  cost_mean=sum(costs)/len(costs)
  cost_sd=sqrt(np.var(costs))
  return feature_counts, reward_mean, reward_sd, cost_mean, cost_sd

def KL(policy1,policy2,policy3,expert_policy1,expert_policy2,expert_policy3):
  """Average divergence of the expert policies from the given policies.

  For every index (x, y, a) in a 9x9x9 grid and every expert/policy pair,
  ``expert*log(expert/policy)`` is accumulated, skipping entries where
  either probability is zero.  The total is divided by ``9*9*3``.
  """
  pairs=(
    (expert_policy1,policy1),
    (expert_policy2,policy2),
    (expert_policy3,policy3),
  )
  total=0
  for index in np.ndindex(9,9,9):
    for expert,learner in pairs:
      target=expert[index]
      if target!=0:
        estimate=learner[index]
        if estimate!=0:
          total=total+target*log(target/estimate)
  return total/(9*9*3)

def constraint_expectation(policy1,policy2,policy3,initial_state,num_trials,lam):
  """Estimate ``lam``-scaled constraint counts of the policies from rollouts.

  Runs ``trial`` ``num_trials`` times from ``initial_state`` (using the
  module-level ``num_action``), stacks the 30x9 rollouts vertically and
  returns ``empirical_constraint_counts`` of the stacked array for ``lam``.
  """
  pieces=[np.zeros((0,9))]
  for _ in range(num_trials):
    rollout=np.copy(trial(initial_state,policy1,policy2,policy3,num_action))
    rows=np.zeros((30,9))
    for step in range(30):
      rows[step,:]=np.copy(rollout[step])
    pieces.append(rows)
  return empirical_constraint_counts(np.vstack(pieces),num_trials,lam)

def inner_loop(initial_state,omega1,omega2,omega3,omega4,lam1,lam2,lam3,lam4,theta,gamma,num_action,num_trials,cost_empirical1,cost_empirical2,cost_empirical3,cost_empirical4, feature_empirical1,feature_empirical2,feature_empirical3,feature_empirical4,num_data1,num_data2,num_data3,num_data4):
  """Run 10 rounds of the four-agent update of ``omega`` and ``lam``.

  In every round each agent computes its three expert policies with
  ``calculate_soft_policy``, estimates feature counts and cost from rollouts,
  and forms gradients as the difference to its empirical values.  Each
  agent's parameters are then replaced by the average with one neighbour
  minus a gradient step; the neighbour pairing alternates between
  (1,2),(3,4) on odd rounds and (1,4),(2,3) on even rounds.  Progress is
  printed; nothing is returned.

  Args:
    initial_state: start state handed to the rollouts.
    omega1..omega4: per-agent reward parameters.
    lam1..lam4: per-agent cost multipliers.
    theta: cost parameters shared by all agents.
    gamma: discount factor.
    num_action: number of actions per state.
    num_trials: number of rollouts per estimate.
    cost_empirical1..4: per-agent empirical costs.
    feature_empirical1..4: per-agent empirical feature counts.
    num_data1..4: per-agent data set sizes, used to scale the gradient step.
  """
  iterations=10
  step=1.0/12000
  # A for loop replaces the manual counter, which was incremented outside
  # the loop body and therefore never advanced.
  for i in range(iterations):
    print('iteration={}' .format(i))
    # calculate_soft_policy returns one policy per expert; all three are
    # required by the rollout helpers.
    policy11,policy12,policy13=calculate_soft_policy(omega1,theta,lam1,gamma,num_action)
    policy21,policy22,policy23=calculate_soft_policy(omega2,theta,lam2,gamma,num_action)
    policy31,policy32,policy33=calculate_soft_policy(omega3,theta,lam3,gamma,num_action)
    policy41,policy42,policy43=calculate_soft_policy(omega4,theta,lam4,gamma,num_action)
    expected_feature1=feature_expectation(policy11,policy12,policy13,initial_state,num_trials)
    expected_feature2=feature_expectation(policy21,policy22,policy23,initial_state,num_trials)
    expected_feature3=feature_expectation(policy31,policy32,policy33,initial_state,num_trials)
    expected_feature4=feature_expectation(policy41,policy42,policy43,initial_state,num_trials)
    cost1=cost_expectation(policy11,policy12,policy13,initial_state,num_trials,theta)
    cost2=cost_expectation(policy21,policy22,policy23,initial_state,num_trials,theta)
    cost3=cost_expectation(policy31,policy32,policy33,initial_state,num_trials,theta)
    cost4=cost_expectation(policy41,policy42,policy43,initial_state,num_trials,theta)
    feature_gradient1=expected_feature1-feature_empirical1
    feature_gradient2=expected_feature2-feature_empirical2
    feature_gradient3=expected_feature3-feature_empirical3
    feature_gradient4=expected_feature4-feature_empirical4
    cost_gradient1=cost1-cost_empirical1
    cost_gradient2=cost2-cost_empirical2
    cost_gradient3=cost3-cost_empirical3
    cost_gradient4=cost4-cost_empirical4

    print('feature_gradient1 is {}' .format(feature_gradient1))
    print('feature_gradient2 is {}' .format(feature_gradient2))
    print('feature_gradient3 is {}' .format(feature_gradient3))
    print('feature_gradient4 is {}' .format(feature_gradient4))

    print('cost_gradient1 is {}' .format(cost_gradient1))
    print('cost_gradient2 is {}' .format(cost_gradient2))
    print('cost_gradient3 is {}' .format(cost_gradient3))
    print('cost_gradient4 is {}' .format(cost_gradient4))

    # NOTE(review): the assignments below run one after another, so the
    # second agent of every pair averages with its neighbour's already
    # updated value - confirm this is intended rather than a simultaneous
    # update.
    # NOTE(review): last_interation_inner_loop adds the cost-gradient term
    # to lam whereas this function subtracts it - confirm the intended sign.
    if i%2==1:
      omega1=0.5*(omega1+omega2)-step*num_data1*feature_gradient1
      omega2=0.5*(omega1+omega2)-step*num_data2*feature_gradient2
      omega3=0.5*(omega3+omega4)-step*num_data3*feature_gradient3
      omega4=0.5*(omega3+omega4)-step*num_data4*feature_gradient4
      lam1=0.5*(lam1+lam2)-step*num_data1*cost_gradient1
      lam2=0.5*(lam1+lam2)-step*num_data2*cost_gradient2
      lam3=0.5*(lam3+lam4)-step*num_data3*cost_gradient3
      lam4=0.5*(lam3+lam4)-step*num_data4*cost_gradient4

    else:
      omega1=0.5*(omega1+omega4)-step*num_data1*feature_gradient1
      omega2=0.5*(omega2+omega3)-step*num_data2*feature_gradient2
      omega3=0.5*(omega2+omega3)-step*num_data3*feature_gradient3
      omega4=0.5*(omega1+omega4)-step*num_data4*feature_gradient4
      lam1=0.5*(lam1+lam4)-step*num_data1*cost_gradient1
      lam2=0.5*(lam2+lam3)-step*num_data2*cost_gradient2
      lam3=0.5*(lam2+lam3)-step*num_data3*cost_gradient3
      lam4=0.5*(lam1+lam4)-step*num_data4*cost_gradient4

    print('omega1 is {}' .format(omega1))
    print('omega2 is {}' .format(omega2))
    print('omega3 is {}' .format(omega3))
    print('omega4 is {}' .format(omega4))

    print('lam1 is {}' .format(lam1))
    print('lam2 is {}' .format(lam2))
    print('lam3 is {}' .format(lam3))
    print('lam4 is {}' .format(lam4))

def last_interation_inner_loop(initial_state,omega1,omega2,omega3,omega4,lam1,lam2,lam3,lam4,theta,gamma,num_action,num_trials,cost_empirical1,cost_empirical2,cost_empirical3,cost_empirical4, feature_empirical1,feature_empirical2,feature_empirical3,feature_empirical4,num_data1,num_data2,num_data3,num_data4):
  """Run 121 rounds of the four-agent update and log metrics to text files.

  The expert policies are read from ``optimal_policy{1,2,3}_file.txt``.  In
  every round each agent computes its three expert policies, samples
  rollouts to obtain feature counts plus reward/cost mean and standard
  deviation, and measures ``KL`` against the expert policies.  Parameters
  are then averaged with one neighbour (pairs (1,2),(3,4) on odd rounds,
  (1,4),(2,3) on even rounds) and moved by a gradient step.  After the last
  round every metric history is written, one value per line, to
  ``divergence{k}_file.txt``, ``reward{k}_mean_file.txt``,
  ``reward{k}_sd_file.txt``, ``cost{k}_mean_file.txt`` and
  ``cost{k}_sd_file.txt`` for k in 1..4.  Nothing is returned.

  Args:
    initial_state: start state handed to the rollouts.
    omega1..omega4: per-agent reward parameters.
    lam1..lam4: per-agent cost multipliers.
    theta: cost parameters shared by all agents.
    gamma: discount factor.
    num_action: number of actions per state (also used to reshape the
      loaded expert policies).
    num_trials: number of rollouts per estimate.
    cost_empirical1..4: per-agent empirical costs.
    feature_empirical1..4: per-agent empirical feature counts.
    num_data1..4: per-agent data set sizes, used to scale the gradient step.
  """
  expert_policies=[]
  for k in (1,2,3):
    flat=np.loadtxt("optimal_policy{}_file.txt".format(k),dtype=float)
    expert_policies.append(flat.reshape(9,9,num_action))
  expert_policy1,expert_policy2,expert_policy3=expert_policies

  iterations=121
  step=1.0/2000000
  num_agents=4

  # One (iterations, 1) history array per metric and agent.
  metric_files=(
    ("divergence","divergence{}_file.txt"),
    ("reward_mean","reward{}_mean_file.txt"),
    ("reward_sd","reward{}_sd_file.txt"),
    ("cost_mean","cost{}_mean_file.txt"),
    ("cost_sd","cost{}_sd_file.txt"),
  )
  history={
    name:[np.zeros((iterations,1)) for _ in range(num_agents)]
    for name,_ in metric_files
  }
  feature_empiricals=(feature_empirical1,feature_empirical2,feature_empirical3,feature_empirical4)
  cost_empiricals=(cost_empirical1,cost_empirical2,cost_empirical3,cost_empirical4)

  for i in range(iterations):
    print('iteration={}' .format(i))
    agent_parameters=((omega1,lam1),(omega2,lam2),(omega3,lam3),(omega4,lam4))
    # Policies, rollouts and divergences are computed in separate passes so
    # that rollouts are sampled in agent order 1..4, as before.
    policies=[
      calculate_soft_policy(omega,theta,lam,gamma,num_action)
      for omega,lam in agent_parameters
    ]
    rollouts=[
      feature_cost_mean_sd(first,second,third,initial_state,num_trials,theta)
      for first,second,third in policies
    ]
    divergences=[
      KL(first,second,third,expert_policy1,expert_policy2,expert_policy3)
      for first,second,third in policies
    ]

    features=[rollout[0] for rollout in rollouts]
    reward_means=[rollout[1] for rollout in rollouts]
    costs=[rollout[3] for rollout in rollouts]
    for k in range(num_agents):
      history["divergence"][k][i]=divergences[k]
      history["reward_mean"][k][i]=rollouts[k][1]
      history["reward_sd"][k][i]=rollouts[k][2]
      history["cost_mean"][k][i]=rollouts[k][3]
      history["cost_sd"][k][i]=rollouts[k][4]

    feature_gradients=[feature-empirical for feature,empirical in zip(features,feature_empiricals)]
    cost_gradients=[cost-empirical for cost,empirical in zip(costs,cost_empiricals)]

    report=(
      ('divergence',divergences),
      ('reward',reward_means),
      ('cost',costs),
      ('feature_gradient',feature_gradients),
      ('cost_gradient',cost_gradients),
    )
    for label,values in report:
      for k,value in enumerate(values,start=1):
        print('{}{} is {}'.format(label,k,value))

    feature_gradient1,feature_gradient2,feature_gradient3,feature_gradient4=feature_gradients
    cost_gradient1,cost_gradient2,cost_gradient3,cost_gradient4=cost_gradients

    # NOTE(review): the assignments below run one after another, so the
    # second agent of every pair averages with its neighbour's already
    # updated value - confirm this is intended rather than a simultaneous
    # update.
    if i%2==1:
      omega1=0.5*(omega1+omega2)-step*num_data1*feature_gradient1
      omega2=0.5*(omega1+omega2)-step*num_data2*feature_gradient2
      omega3=0.5*(omega3+omega4)-step*num_data3*feature_gradient3
      omega4=0.5*(omega3+omega4)-step*num_data4*feature_gradient4
      lam1=0.5*(lam1+lam2)+step*num_data1*cost_gradient1
      lam2=0.5*(lam1+lam2)+step*num_data2*cost_gradient2
      lam3=0.5*(lam3+lam4)+step*num_data3*cost_gradient3
      lam4=0.5*(lam3+lam4)+step*num_data4*cost_gradient4

    else:
      omega1=0.5*(omega1+omega4)-step*num_data1*feature_gradient1
      omega2=0.5*(omega2+omega3)-step*num_data2*feature_gradient2
      omega3=0.5*(omega2+omega3)-step*num_data3*feature_gradient3
      omega4=0.5*(omega1+omega4)-step*num_data4*feature_gradient4
      lam1=0.5*(lam1+lam4)+step*num_data1*cost_gradient1
      lam2=0.5*(lam2+lam3)+step*num_data2*cost_gradient2
      lam3=0.5*(lam2+lam3)+step*num_data3*cost_gradient3
      lam4=0.5*(lam1+lam4)+step*num_data4*cost_gradient4

    for k,value in enumerate((omega1,omega2,omega3,omega4),start=1):
      print('omega{} is {}'.format(k,value))

    for k,value in enumerate((lam1,lam2,lam3,lam4),start=1):
      print('lam{} is {}'.format(k,value))

  # Each history is an (iterations, 1) array, so savetxt writes one value per
  # line; the with-block closes the file even if writing fails.
  for name,template in metric_files:
    for k in range(num_agents):
      with open(template.format(k+1),"w") as handle:
        np.savetxt(handle,history[name][k])


def fake_outer_loop(theta1,theta2,theta3,theta4,initial_state,num_trials,constraint_empirical1,constraint_empirical2,constraint_empirical3,constraint_empirical4):
  """Run 10 rounds of the four-agent update of the cost parameters ``theta``.

  With fixed ``lam_E=0.2`` and a fixed ``omega_E``, each agent computes its
  three expert policies for its own ``theta``, estimates constraint counts
  from rollouts and forms a gradient as the difference to its empirical
  counts.  Each ``theta`` is then replaced by the average with one neighbour
  (pairs (1,2),(3,4) on odd rounds, (1,4),(2,3) on even rounds) minus a
  gradient step.  ``gamma`` and ``num_action`` are read from module level.
  Progress is printed; nothing is returned.

  Args:
    theta1..theta4: per-agent cost parameters.
    initial_state: start state handed to the rollouts.
    num_trials: number of rollouts per estimate.
    constraint_empirical1..4: per-agent empirical constraint counts.
  """
  iterations=10
  step=1.0/12000
  lam_E=0.2
  # NOTE(review): omega_E has four entries while calculate_soft_policy slices
  # omega[4:6] for the third expert - confirm the intended length.
  omega_E=np.mat([1.5,-1.5,1.5,-1.5]).T
  # A for loop replaces the manual counter, which was incremented outside
  # the loop body and therefore never advanced.
  for i in range(iterations):
    print('iteration={}' .format(i))
    # calculate_soft_policy returns one policy per expert; all three are
    # required by constraint_expectation.
    policy11,policy12,policy13=calculate_soft_policy(omega_E,theta1,lam_E,gamma,num_action)
    policy21,policy22,policy23=calculate_soft_policy(omega_E,theta2,lam_E,gamma,num_action)
    policy31,policy32,policy33=calculate_soft_policy(omega_E,theta3,lam_E,gamma,num_action)
    policy41,policy42,policy43=calculate_soft_policy(omega_E,theta4,lam_E,gamma,num_action)
    constraint1=constraint_expectation(policy11,policy12,policy13,initial_state,num_trials,lam_E)
    constraint2=constraint_expectation(policy21,policy22,policy23,initial_state,num_trials,lam_E)
    constraint3=constraint_expectation(policy31,policy32,policy33,initial_state,num_trials,lam_E)
    constraint4=constraint_expectation(policy41,policy42,policy43,initial_state,num_trials,lam_E)

    gradient1=constraint1-constraint_empirical1
    gradient2=constraint2-constraint_empirical2
    gradient3=constraint3-constraint_empirical3
    gradient4=constraint4-constraint_empirical4

    print('gradient1 is {}' .format(gradient1))
    print('gradient2 is {}' .format(gradient2))
    print('gradient3 is {}' .format(gradient3))
    print('gradient4 is {}' .format(gradient4))

    # NOTE(review): the assignments below run one after another, so the
    # second agent of every pair averages with its neighbour's already
    # updated value - confirm this is intended rather than a simultaneous
    # update.
    if i%2==1:
      theta1=0.5*(theta1+theta2)-step*gradient1
      theta2=0.5*(theta1+theta2)-step*gradient2
      theta3=0.5*(theta3+theta4)-step*gradient3
      theta4=0.5*(theta3+theta4)-step*gradient4

    else:
      theta1=0.5*(theta1+theta4)-step*gradient1
      theta2=0.5*(theta2+theta3)-step*gradient2
      theta3=0.5*(theta2+theta3)-step*gradient3
      theta4=0.5*(theta1+theta4)-step*gradient4

    print('theta1 is {}' .format(theta1))
    print('theta2 is {}' .format(theta2))
    print('theta3 is {}' .format(theta3))
    print('theta4 is {}' .format(theta4))

  

# ---------------------------------------------------------------------------
# Experiment configuration.  These names are module-level globals; several
# functions above (choose_action, the *_expectation helpers, fake_outer_loop)
# read num_action and gamma directly from here.
# ---------------------------------------------------------------------------

# Joint start state as a 6x1 column; trial() reads entries (0,1), (2,3) and
# (4,5) as the grid positions of experts 1, 2 and 3.
initial_state=np.mat([0,8,0,4,0,0]).T
num_action=9
gamma=0.9
lam_E=0.2
# NOTE(review): omega_E has four entries while calculate_soft_policy slices
# omega[4:6] for the third expert - confirm the intended length.
omega_E=np.mat([1.5,-1.5,1.5,-1.5]).T
# 26 cost parameters: calculate_soft_policy uses [0:8], [8:16] and [16:26]
# for experts 1, 2 and 3.
theta_E=np.mat([1.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0]).T

# num_trials is both the number of rollouts per estimate and the number of
# stored trajectories; the four per-agent data set sizes add up to it.
num_trials=100
num_data1=10
num_data2=20
num_data3=30
num_data4=40

# Stored trajectories: 30 rows per trajectory, 9 columns per row
# (6 state entries followed by 3 actions).
a=np.loadtxt("optimal_trajectory_file.txt",dtype=float)
trajectories=a.reshape(30*num_trials,9)

# Split the trajectories into four consecutive, non-overlapping per-agent
# data sets of num_data1..num_data4 trajectories.
trajectories1=trajectories[0:30*num_data1,:]
trajectories2=trajectories[30*num_data1:30*(num_data1+num_data2),:]
trajectories3=trajectories[30*(num_data1+num_data2):30*(num_data1+num_data2+num_data3),:]
trajectories4=trajectories[30*(num_data1+num_data2+num_data3):30*(num_data1+num_data2+num_data3+num_data4),:]

# Per-agent empirical feature counts.
feature_empirical1=empirical_feature_counts(trajectories1,num_data1)
feature_empirical2=empirical_feature_counts(trajectories2,num_data2)
feature_empirical3=empirical_feature_counts(trajectories3,num_data3)
feature_empirical4=empirical_feature_counts(trajectories4,num_data4)

# Per-agent empirical costs under theta_E.
cost_empirical1=empirical_cost_counts(trajectories1,num_data1,theta_E)
cost_empirical2=empirical_cost_counts(trajectories2,num_data2,theta_E)
cost_empirical3=empirical_cost_counts(trajectories3,num_data3,theta_E)
cost_empirical4=empirical_cost_counts(trajectories4,num_data4,theta_E)

# Per-agent empirical constraint counts scaled by lam_E.
constraint_empirical1=empirical_constraint_counts(trajectories1,num_data1,lam_E)
constraint_empirical2=empirical_constraint_counts(trajectories2,num_data2,lam_E)
constraint_empirical3=empirical_constraint_counts(trajectories3,num_data3,lam_E)
constraint_empirical4=empirical_constraint_counts(trajectories4,num_data4,lam_E)

# Initial per-agent reward parameters (6 entries, 2 per expert) and cost
# multipliers; only the commented-out loop calls below consume them.
omega1=np.mat([0.0,0.0,0.0,0.0,0.0,0.0]).T
omega2=np.mat([0.025,-0.025,0.025,-0.025,0.025,-0.025]).T
omega3=np.mat([0.05,-0.05,0.05,-0.05,0.05,-0.05]).T
omega4=np.mat([0.075,-0.075,0.075,-0.075,0.075,-0.075]).T
lam1=0.0
lam2=0.001
lam3=0.002
lam4=0.003


#last_interation_inner_loop(initial_state,omega1,omega2,omega3,omega4,lam1,lam2,lam3,lam4,theta_E,gamma,num_action,num_trials,cost_empirical1,cost_empirical2,cost_empirical3,cost_empirical4, feature_empirical1,feature_empirical2,feature_empirical3,feature_empirical4,num_data1,num_data2,num_data3,num_data4)
#inner_loop(initial_state,omega1,omega2,omega3,omega4,lam1,lam2,lam3,lam4,theta_E,gamma,num_action,num_trials,cost_empirical1,cost_empirical2,cost_empirical3,cost_empirical4, feature_empirical1,feature_empirical2,feature_empirical3,feature_empirical4,num_data1,num_data2,num_data3,num_data4)
#theta=np.mat([0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]).T
#centralized_outer_loop(theta,initial_state,num_trials,constraint_empirical1,constraint_empirical2,constraint_empirical3,constraint_empirical4,num_data1,num_data2,num_data3,num_data4)


#saved_policy=policy.reshape(9*9*9,9*9*9)
#soft_policy_file=open("bellman_policy_file.txt","w")

#for entry in saved_policy:
#  np.savetxt(soft_policy_file,entry)
#soft_policy_file.close()
#distribution=np.loadtxt("bellman_policy_file.txt",dtype=float)
#policy=distribution.reshape(9,9,9,9,num_action,num_action)
#trajectory=trial(initial_state,policy,num_action)
#print(trajectory)




# Load the three stored expert policies and reshape each flat table to
# (9, 9, num_action), i.e. one action distribution per grid state.
distribution1=np.loadtxt("optimal_policy1_file.txt",dtype=float)
expert_policy1=distribution1.reshape(9,9,num_action)
distribution2=np.loadtxt("optimal_policy2_file.txt",dtype=float)
expert_policy2=distribution2.reshape(9,9,num_action)
distribution3=np.loadtxt("optimal_policy3_file.txt",dtype=float)
expert_policy3=distribution3.reshape(9,9,num_action)

# Sample num_trials rollouts of the expert policies and print the estimated
# feature counts followed by reward mean/sd and cost mean/sd.
feature, reward_mean, reward_sd, cost, cost_sd=feature_cost_mean_sd(expert_policy1,expert_policy2,expert_policy3,initial_state,num_trials,theta_E)
print(feature)
print(reward_mean)
print(reward_sd)
print(cost)
print(cost_sd)


















